Diffusing Away From GANs and Transformers
Investigating the math and code behind the current hype surrounding diffusion models and exploring their effectiveness, applicability, and drawbacks.
1. Setting the Background
The current boom of Machine Learning is primarily driven by Deep Learning: the practice of representing data in an N-dimensional space (usually not humanly comprehensible) and learning to produce useful output from this representation, such as classifying an image as a cat or a dog. Such representations can be used not only to classify but also to generate images, bounding boxes, segmentation masks, etc.
In this article, we will survey the common generative models, their advantages and drawbacks, and how diffusion models, the new hype, change the game for generative networks. A word of warning: there is going to be some mathematics involved here, but I'll make sure to do some hand-holding for those who would rather avoid the complex equations.
To get started, let's explore the two most famous architectures used for generating images: Autoencoders and Generative Adversarial Networks.
1.1. Autoencoders and their variations
An autoencoder, as the name suggests, is an encoder of information. A standard autoencoder works by compressing the input into a fixed-size embedding. This embedding contains the information required to produce an output from the desired domain distribution.

A standard autoencoder has two networks: the encoder and the decoder. The aforementioned latent representation is generated by the encoder block which is usually a convolutional neural network. Using the information in this hidden representation, the decoder, also usually a convolutional neural network, can produce the desired output such as a denoised image, segmentation mask, etc.
A common problem with the original autoencoder is that the latent representation $z$ is unconstrained, which can produce undesired results and destabilize training. This led to the introduction of variational autoencoders, the most common autoencoder variant found today. Variational autoencoders constrain the mean and variance of the hidden representation, pulling the latent distribution towards a standard normal, which leads to cleaner outputs and better convergence.
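To make this constraint concrete, here is a minimal sketch (in Flax, the same library used for the implementation later in this article) of a VAE-style encoder head: it predicts a mean and log-variance, samples the latent with the reparameterization trick, and returns the KL term that pulls the latent towards a standard normal. The layer sizes and `latent_dim` are illustrative assumptions, not a reference implementation.

```python
import jax.numpy as jnp
import jax.random as random
import flax.linen as nn

class VAEEncoder(nn.Module):
    latent_dim: int = 16  # illustrative choice

    @nn.compact
    def __call__(self, x, rng):
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        mean = nn.Dense(self.latent_dim)(x)
        logvar = nn.Dense(self.latent_dim)(x)
        # Reparameterization: z = mean + sigma * eps, with eps ~ N(0, I)
        eps = random.normal(rng, mean.shape)
        z = mean + jnp.exp(0.5 * logvar) * eps
        # KL divergence to N(0, I) -- added to the reconstruction loss during training
        kl = -0.5 * jnp.sum(1 + logvar - mean**2 - jnp.exp(logvar), axis=-1)
        return z, kl
```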
In spite of these changes, autoencoders are still notorious for producing blurry, noisy images. This changed with generative adversarial networks, which have long held the state of the art in high-quality image generation.
1.2. Generative Adversarial Networks
The generative adversarial network (GAN for short) is the second of our two well-known architectures. GANs are famous for producing high-fidelity outputs for various tasks such as image super-resolution, segmentation mask generation, image translation, and inpainting.

A generative adversarial network consists of two neural networks: the generator and the discriminator. The generator (usually a convolutional neural network) takes random noise as input and produces an image from the desired domain. The discriminator is tasked with judging whether the generated image is real, i.e., whether the output belongs to the desired distribution.
These networks usually require large compute resources for training, but their main drawback is mode collapse: a situation where the generator learns to exploit the discriminator's bias and produces outputs from a narrow distribution that reliably fools the discriminator. When analyzed, these generated images share common recurring patterns.
Many techniques have been employed to resolve these problems, but the training instability and expense usually deter research in this field. To improve on this, some researchers revisited the previously studied fundamentals of diffusion to propose a training method that reduces the need for large compute resources while, in many cases, producing better results than GANs.
2. What is Diffusion?
Diffusion models are based on the well-researched concept of diffusion in physics.
In this context, diffusion is defined as the process by which an environment attains homogeneity: introducing a new element creates a potential gradient, and particles flow along it until uniformity is restored. Diffusion as a notion is about attaining uniformity in a system.

Diffusion of particles in an environment
But are the states of a diffusion process reversible? Can we identify these newly introduced particles in a homogeneous system? This is exactly what we try to do with diffusion models!
Consider that we have an image: we gradually add noise to it in extremely small steps until we reach a stage ($t = T$) where the image is completely unrecognizable and has become pure random noise.

Once the forward "noise addition" chain $q$ is complete, we use a deep learning model with trainable parameters $\theta$ to try and recover the image from the noise (the denoising phase $p_\theta$) by estimating the noising chain at every timestamp.
Diffusion Task:
Gradually add noise to the image in $T$ steps in the forward process, then try to recover the original image from the noisy image at $t = T$ in the backward process by tracing the chain backwards.
Diffusion was first introduced in 2015 [1] but was recently revived and developed by researchers at Stanford and Google Brain. Diffusion models are typically classified into two types: continuous diffusion models and discrete diffusion models. In the forward chain, the former adds Gaussian noise to continuous signals, whilst the latter obfuscates discrete input tokens using a Markov transition matrix. We'll look at the former in this post, understanding and implementing the equations from the DDPM paper [2] in JAX.
2.1. Forward Pass
The diffusion process is fixed to a Markov chain that gradually adds Gaussian noise to the data according to a variance schedule $\beta_1, \dots, \beta_T$ where $0 < \beta_1 < \beta_2 < \dots < \beta_T < 1$.
Let us break this sentence down:
2.1.1. Markov Chain
A Markov chain is a sequence of events or states that follows the Markov property. The Markov property states that the distribution of the variable at an arbitrary point in the chain is determined only by the state immediately preceding it.

This means that the state $x_2$ depends only on $x_1$. Similarly, the state $x_3$ depends only on $x_2$, but since $x_2$ depends on $x_1$, any arbitrary state in the chain is indirectly dependent on all the states that occur before it.
The Markov property implies that the probability of occurrence of a chain of events from $x_1$ to $x_T$, given the first state $x_0$, factorizes as follows:

$$q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1})$$
In our case, the probability of a state $x_t$ given $x_{t-1}$ is determined directly by the addition of noise, since the amount of noise in the image at a given stage depends only on how much noise was already present.
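As a toy illustration of the Markov property (separate from the diffusion chain itself), the sketch below samples a hypothetical two-state Markov chain from a transition matrix; note that each step looks only at the current state.

```python
import numpy as np

# A hypothetical two-state chain: rows are the current state, columns the next
transition = np.array([[0.9, 0.1],
                       [0.5, 0.5]])

rng = np.random.default_rng(0)
state = 0
chain = [state]
for _ in range(10):
    # The next state depends only on the current one (the Markov property)
    state = int(rng.choice(2, p=transition[state]))
    chain.append(state)
print(chain)  # e.g. [0, 0, 0, 1, ...] depending on the seed
```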
2.1.2. Addition of Gaussian Noise
As discussed above, we will need to calculate the probability $q(x_t \mid x_{t-1})$ for generating an image at a given timestamp $t$. For this, we will need to sample some noise and incrementally add it to the image.
Noise drawn from a Gaussian distribution depends on only two factors: the mean and the standard deviation (or variance). By changing these two values, it is possible to generate infinitely many noise distributions, one of which can then be added to the image at every step.
This is where the variance schedule $\beta_t$ comes into play. For diffusion models, we fix the variance schedule as we move along the chain. The noising transition at a given state is defined as:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\, x_{t-1},\ \beta_t \mathbf{I}\right)$$
The above line says that we generate $x_t$ from a Gaussian distribution ($\mathcal{N}$) by taking $\sqrt{1-\beta_t}\, x_{t-1}$ as the mean and $\beta_t$ as the variance for that step. Combining this definition with the previous equation for $q(x_{1:T} \mid x_0)$, we can now sample the noise for any given step.
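A naive way to realize this definition is to walk the chain one step at a time. The sketch below (using an illustrative `beta` schedule shaped like the one defined in section 2.3) makes the inefficiency plain, which is exactly what the next subsection fixes.

```python
import jax.numpy as jnp
import jax.random as random

beta = jnp.linspace(0.0001, 0.02, 200)  # illustrative schedule, as in section 2.3

def naive_forward_noising(key, x_0, t):
    # Walks every step of the chain: O(t) normal samples per image
    x = x_0
    for step in range(t):
        key, subkey = random.split(key)
        noise = random.normal(subkey, x.shape)
        x = jnp.sqrt(1.0 - beta[step]) * x + jnp.sqrt(beta[step]) * noise
    return x
```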
2.1.3. The Reparameterization Trick
For our training task the model, given the timestamp, is responsible for removing the noise added to the image at that timestamp. To generate a noisy image for the said timestamp, we would need to iterate through the entire chain. This is extremely inefficient: Python loops are slow, and for a large timestamp the chain takes far too long to traverse.
To avoid this, we use a reparameterization trick that lets us sample the noisy image at the required timestamp in a single step. This trick works because the sum of two Gaussians is also a Gaussian. The reparameterized formula is given below:

$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, \mathbf{I})$$

where $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$.
Compared to the previous equation, we can see that we have isolated the variance schedule as $\alpha_t = 1 - \beta_t$ and pre-calculated its cumulative product $\bar{\alpha}_t$. Using this equation, we can now directly sample the noisy image at any timestamp from just the original image.
2.2. Backward Pass
The backward pass aims to turn the noisy image into the desired domain distribution, whether it be for denoising, image super-resolution, or just about anything else!
2.2.1. Autoencoders are back?
For this task, we can use any model with a large enough capacity. Papers usually use autoencoder-style U-Nets with global attention, which have proven empirically effective for tasks such as generation and segmentation. The only difference between the U-Net used for diffusion and a standard attention-augmented U-Net is that additional timestamp information is integrated into the model. In general, models with increased width reach the desired sample quality faster than models with increased depth [4].

The above diagram compares a standard U-Net to the modified U-Net that integrates the information provided by the timestamp. The timestamp is first embedded into an N-dimensional vector and is then added at every layer of the model so that the model can learn the correlation between the noise and the timestamp and denoise accordingly.
But wait! Do you notice something weird? The model takes the timestamp and the noisy image as input and outputs noise $\epsilon_\theta$?
Yes! Commonly adopted diffusion models output noise, but that doesn't mean you cannot directly output the image. The model's aim is to output the noise distribution it believes is present in the picture, and this is done for the sake of convenience: if we output the noise, the loss calculation simplifies, which makes the whole process easier to follow.
2.2.2. The Training Loop
The standard equation for the backward pass can be given as:

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\!\left(x_{t-1};\ \mu_\theta(x_t, t),\ \Sigma_\theta(x_t, t)\right)$$
Here, we aim to generate $x_{t-1}$ from the state $x_t$ according to the mean and variance produced by the model. Researchers found that fixing the variance to $\sigma_t^2 = \beta_t$ helps the model converge better. Though this is still under active experimentation, we will go ahead and assume the output variance is fixed to $\beta_t$.
Now that we have defined what we need to do, let us define the loss function. The loss function used for diffusion models is derived from the ELBO commonly used with variational autoencoders. This loss defines a lower-bound objective, and a simplified version of it can be given as:

$$L_{\text{simple}} = \mathbb{E}_{t,\, x_0,\, \epsilon}\left[\left\lVert \epsilon - \epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon,\ t\right) \right\rVert^2\right]$$
The descent function takes three inputs: the timestamp ($t$), the original image ($x_0$), and some randomly generated Gaussian noise to be added to the original image ($\epsilon$). In the forward-propagation step, the model predicts the noise it believes was added to the image, and we calculate the mean squared error between the model's output noise and the original noise. This loss value is then used to calculate the gradients and backpropagate through the autoencoder model.
2.2.3. The Inference Loop
After successfully completing the training process, we must define an inference loop that can generate new samples for us when provided with Gaussian noise. The general algorithm for sampling is given as follows:

Let us run through one loop of sampling. We first sample random noise, which we treat as the image at step $T$. Then, we simply loop backward from $t = T$ to $t = 1$, where we sample the image according to the following formula:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\, \epsilon_\theta(x_t, t)\right) + \sigma_t z, \qquad z \sim \mathcal{N}(0, \mathbf{I})$$
This essentially says that we use the mean predicted by the model and set the standard deviation to $\sigma_t = \sqrt{\beta_t}$.
2.3. Implementing a Denoising Diffusion Model (Colab Notebook)
Now that we have skimmed over the theory of training diffusion models, let's get to implementing it.
- Defining the imports and initializing the run
```python
import jax
import optax
import os
import math
import wandb
import random as r
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import jax.numpy as jnp
import jax.random as random
import flax.linen as nn
from flax.training import train_state
import matplotlib.pyplot as plt
from typing import Callable
from PIL import Image
from tqdm.notebook import tqdm

# Set only 80% of memory to be accessible. This avoids OOM due to pre-allocation.
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.8

# Defining some hyperparameters
NUM_EPOCHS = 10
BATCH_SIZE = 64
NUM_STEPS_PER_EPOCH = 60000 // 64  # MNIST has 60,000 training samples

USER = ""     # Enter your W&B username
PROJECT = ""  # Enter your project name

# Initializing W&B run
wandb.init(entity=USER, project=PROJECT)
```
- Defining the Forward Pass
The forward pass algorithm can be written as:
- Define the total timesteps ($T$) for the chain
- Generate $\beta_t$, $\alpha_t$, and $\bar{\alpha}_t$ for every $t \in [1, T]$
- Generate the noisy image according to $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon$
```python
# Defining a constant value for T
timesteps = 200

# Defining beta for all t's in T steps
beta = jnp.linspace(0.0001, 0.02, timesteps)

# Defining alpha and its derivatives according to reparameterization trick
alpha = 1 - beta
alpha_bar = jnp.cumprod(alpha, 0)
sqrt_alpha_bar = jnp.sqrt(alpha_bar)
one_minus_sqrt_alpha_bar = jnp.sqrt(1 - alpha_bar)

# Implement noising logic according to reparameterization trick
def forward_noising(key, x_0, t):
    noise = random.normal(key, x_0.shape)
    reshaped_sqrt_alpha_bar_t = jnp.reshape(jnp.take(sqrt_alpha_bar, t), (-1, 1, 1, 1))
    reshaped_one_minus_sqrt_alpha_bar_t = jnp.reshape(jnp.take(one_minus_sqrt_alpha_bar, t), (-1, 1, 1, 1))
    noisy_image = reshaped_sqrt_alpha_bar_t * x_0 + reshaped_one_minus_sqrt_alpha_bar_t * noise
    return noisy_image, noise

# Let us visualize the output image at a few timestamps.
# `sample_mnist` is assumed to be a single (32, 32, 1) image from the dataset.
fig = plt.figure(figsize=(15, 30))
for index, i in enumerate([10, 50, 100, 185]):
    noisy_im, noise = forward_noising(random.PRNGKey(0), jnp.expand_dims(sample_mnist, 0), jnp.array([i,]))
    plt.subplot(1, 4, index + 1)
    plt.imshow(jnp.squeeze(jnp.squeeze(noisy_im, -1), 0), cmap='gray')
plt.show()
```

As we can see, the number gets progressively more difficult to identify as $t$ increases. At $t = 185$, the number is almost completely indistinguishable from the added noise.
- Defining the Model
We will be using an attention-augmented UNet architecture for our task. As discussed before, the model takes an additional time embedding to capture the correlation between the timestamp and the amount of noise added to the image.
Before we define the model itself, let us define how the time must be embedded into the model. We use the popular sinusoidal projection that is also commonly used for positional encodings in transformers. We project the time constant into a fixed-dimensional space (in our case, 128-dimensional) which we will integrate into the model later. Let us code this:
```python
class SinusoidalEmbedding(nn.Module):
    dim: int = 32

    @nn.compact
    def __call__(self, inputs):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = jnp.exp(jnp.arange(half_dim) * -emb)
        emb = inputs[:, None] * emb[None, :]
        emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], -1)
        return emb


class TimeEmbedding(nn.Module):
    dim: int = 32

    @nn.compact
    def __call__(self, inputs):
        time_dim = self.dim * 4
        se = SinusoidalEmbedding(self.dim)(inputs)
        # Projecting the embedding into a 128 dim space
        x = nn.Dense(time_dim)(se)
        x = nn.gelu(x)
        x = nn.Dense(time_dim)(x)
        return x
```
The first building block for the UNet is the attention mechanism. The attention that we will be using is the standard dot-product attention with eight heads.
```python
class Attention(nn.Module):
    dim: int
    num_heads: int = 8
    use_bias: bool = False
    kernel_init: Callable = nn.initializers.xavier_uniform()

    @nn.compact
    def __call__(self, inputs):
        batch, h, w, channels = inputs.shape
        inputs = inputs.reshape(batch, h * w, channels)
        batch, n, channels = inputs.shape
        scale = (self.dim // self.num_heads) ** -0.5
        qkv = nn.Dense(self.dim * 3, use_bias=self.use_bias, kernel_init=self.kernel_init)(inputs)
        qkv = jnp.reshape(qkv, (batch, n, 3, self.num_heads, channels // self.num_heads))
        qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]
        attention = (q @ jnp.swapaxes(k, -2, -1)) * scale
        attention = nn.softmax(attention, axis=-1)
        x = (attention @ v).swapaxes(1, 2).reshape(batch, n, channels)
        x = nn.Dense(self.dim, kernel_init=nn.initializers.xavier_uniform())(x)
        x = jnp.reshape(x, (batch, int(x.shape[1] ** 0.5), int(x.shape[1] ** 0.5), -1))
        return x
```
Next, we will be defining the ResNet block. This ResNet block is only slightly different from the original ResNet block because it also incorporates a time embedding.
```python
class Block(nn.Module):
    dim: int = 32
    groups: int = 8

    @nn.compact
    def __call__(self, inputs):
        conv = nn.Conv(self.dim, (3, 3))(inputs)
        norm = nn.GroupNorm(num_groups=self.groups)(conv)
        activation = nn.silu(norm)
        return activation


class ResnetBlock(nn.Module):
    dim: int = 32
    groups: int = 8

    @nn.compact
    def __call__(self, inputs, time_embed=None):
        x = Block(self.dim, self.groups)(inputs)
        if time_embed is not None:
            time_embed = nn.silu(time_embed)
            time_embed = nn.Dense(self.dim)(time_embed)
            x = jnp.expand_dims(jnp.expand_dims(time_embed, 1), 1) + x
        x = Block(self.dim, self.groups)(x)
        res_conv = nn.Conv(self.dim, (1, 1), padding="SAME")(inputs)
        return x + res_conv
```
Finally, we will implement the UNet. The UNet will have four upsampling and four downsampling blocks.
```python
class UNet(nn.Module):
    dim: int = 8
    dim_scale_factor: tuple = (1, 2, 4, 8)
    num_groups: int = 8

    @nn.compact
    def __call__(self, inputs):
        inputs, time = inputs
        channels = inputs.shape[-1]
        x = nn.Conv(self.dim // 3 * 2, (7, 7), padding=((3, 3), (3, 3)))(inputs)
        time_emb = TimeEmbedding(self.dim)(time)

        dims = [self.dim * i for i in self.dim_scale_factor]
        pre_downsampling = []

        # Downsampling phase
        for index, dim in enumerate(dims):
            x = ResnetBlock(dim, self.num_groups)(x, time_emb)
            x = ResnetBlock(dim, self.num_groups)(x, time_emb)
            att = Attention(dim)(x)
            norm = nn.GroupNorm(self.num_groups)(att)
            x = norm + x
            # Saving this output for residual connection with the upsampling layer
            pre_downsampling.append(x)
            if index != len(dims) - 1:
                x = nn.Conv(dim, (4, 4), (2, 2))(x)

        # Middle block
        x = ResnetBlock(dims[-1], self.num_groups)(x, time_emb)
        att = Attention(dims[-1])(x)
        norm = nn.GroupNorm(self.num_groups)(att)
        x = norm + x
        x = ResnetBlock(dims[-1], self.num_groups)(x, time_emb)

        # Upsampling phase
        for index, dim in enumerate(reversed(dims)):
            x = jnp.concatenate([pre_downsampling.pop(), x], -1)
            x = ResnetBlock(dim, self.num_groups)(x, time_emb)
            x = ResnetBlock(dim, self.num_groups)(x, time_emb)
            att = Attention(dim)(x)
            norm = nn.GroupNorm(self.num_groups)(att)
            x = norm + x
            if index != len(dims) - 1:
                x = nn.ConvTranspose(dim, (4, 4), (2, 2))(x)

        # Final ResNet block and output convolutional layer
        x = ResnetBlock(dim, self.num_groups)(x, time_emb)
        x = nn.Conv(channels, (1, 1), padding="SAME")(x)
        return x
```
- Training Loop
Here, we define the training functions and loops in JAX.
According to the formula we studied previously, the gradient descent step takes the noisy image, the original noise, and the timestamp, and returns the loss.
```python
# Calculate the gradients and loss values for the specific timestamp
@jax.jit
def apply_model(state, noisy_images, noise, timestamp):
    """Computes gradients and loss for a single batch."""
    def loss_fn(params):
        pred_noise = model.apply({'params': params}, [noisy_images, timestamp])
        loss = jnp.mean((noise - pred_noise) ** 2)
        return loss

    grad_fn = jax.value_and_grad(loss_fn, has_aux=False)
    loss, grads = grad_fn(state.params)
    return grads, loss

# Helper function for applying the gradients to the model
@jax.jit
def update_model(state, grads):
    return state.apply_gradients(grads=grads)
```
The training step performs the following functions:
- Generate random PRNGKeys for generating the timestamps and noise
- Generate the noisy images
- Forward propagate on the UNet
- Update the model weights in the backward propagation process according to the calculated gradients
- Display loss at that particular step and return the current state and loss
```python
# Define the training step
def train_epoch(epoch_num, state, train_ds, batch_size, rng):
    epoch_loss = []
    # epoch_num starts at 1, so subtract 1 to get the steps already elapsed
    num_steps_elapsed = (epoch_num - 1) * NUM_STEPS_PER_EPOCH
    for index, batch_images in enumerate(tqdm(train_ds)):
        rng, tsrng = random.split(rng)
        timestamps = random.randint(
            tsrng,
            shape=(batch_images.shape[0],),
            minval=0,
            maxval=timesteps
        )
        noisy_images, noise = forward_noising(rng, batch_images, timestamps)
        grads, loss = apply_model(state, noisy_images, noise, timestamps)
        state = update_model(state, grads)
        epoch_loss.append(loss)
        wandb.log({"train_loss": loss, 'step': num_steps_elapsed + (index + 1)})
        if index % 10 == 0:
            print(f"Loss at step {index}: ", loss)
        # Timestamps are not needed anymore. Saves some memory.
        del timestamps
    train_loss = np.mean(epoch_loss)
    return state, train_loss
```
We will create two helper functions for loading the dataset and creating the training state for the model.
```python
# Load and preprocess the MNIST data
def get_datasets():
    ds = tfds.load('mnist', as_supervised=True)
    train_ds, test_ds = ds['train'], ds['test']

    def preprocess(x, y):
        # Scale pixels to [-1, 1] and resize the 28x28 digits to 32x32
        return tf.image.resize(tf.cast(x, tf.float32) / 127.5 - 1, (32, 32))

    train_ds = train_ds.map(preprocess, tf.data.AUTOTUNE)
    test_ds = test_ds.map(preprocess, tf.data.AUTOTUNE)
    train_ds = train_ds.shuffle(5000).batch(BATCH_SIZE)
    test_ds = test_ds.batch(BATCH_SIZE)
    return tfds.as_numpy(train_ds), tfds.as_numpy(test_ds)

# Instantiating the UNet. dim=32 matches the 128-dim (32 * 4) time embedding
# discussed earlier; the exact value used originally may differ.
model = UNet(32)

# Creating a train state for our Flax UNet
def create_train_state(rng):
    """Creates initial `TrainState`."""
    params = model.init(rng, [jnp.ones([1, 32, 32, 1]), jnp.ones([1,])])['params']
    tx = optax.adam(1e-4)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
```
Before we start the training process, we will define the logic for training which goes as follows:
- Generate a PRNGKey which will be used to initialize the weights
- Create a training state for our model using the helper function defined before
- Iterate over NUM_EPOCHS and for each epoch, call the train_epoch() function.
- Log the state at the end of the epoch for future reference (This is optional)
```python
log_state = []

def train(train_ds) -> train_state.TrainState:
    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng)
    for epoch in range(1, NUM_EPOCHS + 1):
        rng, input_rng = jax.random.split(rng)
        state, train_loss = train_epoch(epoch, state, train_ds, BATCH_SIZE, input_rng)
        print("Training loss: ", train_loss)
        # Log the state at the end of every epoch for future reference (optional)
        log_state.append(state)
    return state
```
All that is left to do now is load the dataset and call this train function:
```python
train_ds, test_ds = get_datasets()

# This function will return the final trained state after `NUM_EPOCHS` epochs
trained_state = train(train_ds)
```
Let us log and monitor the training loss using Weights and Biases!
- Inference Loop
Now that we have trained our model, let us implement a helper function that can take randomly initialized noise and convert it into something that belongs to the input distribution and is more recognizable.
```python
# This function defines the logic of getting x_{t-1} given x_t
def backward_denoising(x_t, pred_noise, t):
    alpha_t = jnp.take(alpha, t)
    alpha_t_bar = jnp.take(alpha_bar, t)
    eps_coef = (1 - alpha_t) / (1 - alpha_t_bar) ** .5
    mean = 1 / (alpha_t ** 0.5) * (x_t - eps_coef * pred_noise)
    var = jnp.take(beta, t)
    eps = random.normal(key=random.PRNGKey(r.randint(1, 100)), shape=x_t.shape)
    return mean + (var ** 0.5) * eps
```
To use this function, we need random noise $x_T$, the model's noise prediction $\epsilon_\theta(x_t, t)$, and the timestamp $t$. Let us see the code for generating these:
```python
# Generating Gaussian noise
x = random.normal(random.PRNGKey(42), (1, 32, 32, 1))

img_list = []

for i in range(0, timesteps):
    t = jnp.expand_dims(jnp.array(timesteps - i - 1, jnp.int32), 0)
    pred_noise = model.apply({'params': trained_state.params}, [x, t])
    x = backward_denoising(x, pred_noise, t)
    # Log the image after every 25 iterations
    if i % 25 == 0:
        img_list.append(jnp.squeeze(jnp.squeeze(x, 0), -1))

# Generate a GIF from the logged images.
# Undo the [-1, 1] normalization ((x + 1) * 127.5) before converting to uint8.
imgs = (Image.fromarray(((np.array(im) + 1) * 127.5).astype(np.uint8)) for im in img_list)
img = next(imgs)  # extract first image from iterator
img.save(fp="output.gif", format='GIF', append_images=imgs,
         save_all=True, duration=200, loop=0)

# Log the GIF to W&B
wandb.log({"Reconstruction-GIFs": wandb.Image("output.gif")})
```
Note: the loop above is unoptimized, and image generation may take around 3-4 minutes depending on the accelerator.
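One simple optimization, offered here as a suggestion rather than part of the original notebook: wrap the repeated model call in `jax.jit` so the UNet forward pass is compiled once instead of being re-dispatched from Python on every step.

```python
import jax

# Compile the forward pass once; later calls with the same shapes reuse it
fast_apply = jax.jit(model.apply)

# Inside the sampling loop, replace `model.apply(...)` with:
# pred_noise = fast_apply({'params': trained_state.params}, [x, t])
```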
2.4. Using Diffusion Models with Text Prompts
The rapid growth of prompt-based image generation models like DALLE-2 [10] and Imagen [11] can be credited to the rise of diffusion models. Prompt-based generation is the task of generating a viable, high-quality image from a text description or class label. Some examples are given below:

These models are particularly impressive since they are capable of understanding visual context and producing high-resolution outputs. For example, consider the picture in the middle: to produce such an image, the model must understand what a dog is, what a cat is, and what a mirror is, along with the concept of reflection in a mirror. On top of that, if you look closely, you can see that the model also blurs the reflection slightly, giving it a realistic touch. This text-based generation is possible due to the integration of textual embeddings into the model.

Continuing with the same diagram as before, we can add a label or text prompt, embedded into the model either via a pretrained NLP model's outputs or via a learned embedding. Though this method works well, it struggles to understand and generate the sequential structure of text-based imagery such as signboards or warnings.
This led to the introduction of classifier-free guidance. In this method, the model samples the output once with the text and once without.

The scaled difference is taken in the direction of the text-conditioned output
The text-conditioned prediction (the former) is naturally more faithful to the prompt than the unconditional one (the latter), so the authors [7] take both predictions, compute their difference, and scale this difference by a predefined guidance factor in the direction of the former. Sampling with this method helps the model grasp the missing sequentiality and context.
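In pseudo-JAX, one classifier-free-guidance step might look like the sketch below. The conditioning interface is a pure assumption for illustration: `text_emb`, the unconditional `null_emb`, and the guidance scale `s` are hypothetical names, since the exact way the embedding enters the model varies between papers.

```python
def guided_noise(model, params, x_t, t, text_emb, null_emb, s=7.5):
    # One pass conditioned on the text, one unconditional pass
    eps_text = model.apply({'params': params}, [x_t, t, text_emb])
    eps_uncond = model.apply({'params': params}, [x_t, t, null_emb])
    # Move the unconditional prediction towards the text-conditioned one,
    # scaled by the guidance factor s
    return eps_uncond + s * (eps_text - eps_uncond)
```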
Ongoing research attempts to find better methods to embed the textual information within the model to help the model with sequentiality and better image generation.
3. Improvements in Diffusion Models
Diffusion models are far from perfect! With every paper, they get better and better. Let us discuss two ways in which the original denoising diffusion model has been improved.
3.1. Approximation
Usually, for these models to work, the number of timesteps $T$ must be set to a high value such as 1000 or 2000. This makes inference long and computationally heavy. Recent works have brought this figure down to just 25-50 steps [3][4], or even 10 steps in the case of vector-quantized diffusion models [9]! This is done by using reparameterization tricks or by analytically skipping steps during the backward pass.
To understand this reduction, let us revisit the reparameterized equation used in the forward process:

$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon$$

Comparing this with the original equation, we can solve for the predicted $x_0$ and write the sampling step in the generalized form:

$$x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\left(\frac{x_t - \sqrt{1-\bar{\alpha}_t}\, \epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}\right) + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\; \epsilon_\theta(x_t, t) + \sigma_t \epsilon_t$$

Let $\sigma_t = \eta \sqrt{\tfrac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}} \sqrt{1-\tfrac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}}$, where $\eta$ is treated as a hyperparameter. If we set $\eta$ to 1, we get the standard DDPM, but if we set $\eta$ to 0, the sampling process becomes deterministic. This is exactly what the authors did in the Denoising Diffusion Implicit Models (DDIM) paper [3].
During generation, we then sample only for a sub-sequence of timesteps $\{\tau_1, \dots, \tau_S\} \subset \{1, \dots, T\}$, where $S \ll T$. With this technique of setting $\eta$ to 0, it is possible to get perceptually clean, high-quality images in 100 or fewer steps from a model trained with over 1000!
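Reusing the `alpha_bar` array from section 2.3, a deterministic DDIM update (the $\eta = 0$ case) could be sketched as follows; the strided pair `(t, t_prev)` is chosen by the caller, which is what lets us skip most of the chain.

```python
import jax.numpy as jnp

def ddim_step(x_t, pred_noise, t, t_prev):
    # Deterministic DDIM update (eta = 0), reusing `alpha_bar` from section 2.3
    a_t = jnp.take(alpha_bar, t)
    a_prev = jnp.take(alpha_bar, t_prev)
    # Predict x_0 from the current noisy image and the predicted noise
    x0_pred = (x_t - jnp.sqrt(1.0 - a_t) * pred_noise) / jnp.sqrt(a_t)
    # Jump directly to the (possibly much earlier) timestep t_prev
    return jnp.sqrt(a_prev) * x0_pred + jnp.sqrt(1.0 - a_prev) * pred_noise
```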
3.2. Latent Diffusion
Another proposition to optimize the training is to reduce the size of the image on which diffusion is performed. This is called latent diffusion.

With latent diffusion, we can avoid processing large 512×512px or 1024×1024px images and instead shrink them to a friendlier size such as 28×28px or 32×32px. Diffusion is then applied to this downsampled representation, whose output is then upsampled back to the original resolution. This, when combined with the approximation trick, can sample very high-resolution images comparatively quickly.
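Schematically, latent diffusion is a thin wrapper around the pixel-space pipeline above. In the hedged sketch below, `encoder` and `decoder` stand in for a pretrained autoencoder and `run_backward_loop` for the sampling loop from section 2.3; all three names are assumptions for illustration.

```python
# Training: diffuse in the small latent space instead of pixel space
def latent_forward_noising(key, image, t, encoder):
    z_0 = encoder(image)                 # e.g. 512x512x3 -> 32x32x4 latent
    return forward_noising(key, z_0, t)  # reuse the pixel-space logic unchanged

# Sampling: denoise a random latent, then decode it back to full resolution
def latent_sample(run_backward_loop, decoder, key, latent_shape=(1, 32, 32, 4)):
    z = run_backward_loop(key, latent_shape)
    return decoder(z)
```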
As research on diffusion models progresses, we will soon get to see more optimizations and tricks to increase image quality and reduce training overheads.
4. Advantages of Diffusion
Even though diffusion is a new domain for most practitioners, it has some clear advantages over previous methods.
- Less parameter tuning and stable training: As compared to GANs, diffusion models don't require training stabilization tricks.
- Faster training and inference with optimized diffusion techniques: As discussed above, techniques like latent diffusion combined with approximation lead to much faster inference times as well.
- High-fidelity outputs: The perceptual quality of the generated images is superior to that of GANs and autoencoders.
5. References
- GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models (Alex Nichol et al, 2022)
- Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding (Chitwan Saharia et al, 2022)